# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch
import ttnn

from .utils import from_torch_fast_2d

if TYPE_CHECKING:
    # DiTParallelConfig is only used in type annotations; this import path is an assumption.
    from .parallel_config import DiTParallelConfig


@dataclass
class TtLinearParameters:
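    """ttnn parameters of a linear layer.

    The weight is stored with input features first, i.e. shape (in_features, out_features),
    the transpose of the torch.nn.Linear layout, so that it can be consumed directly by
    ttnn.linear. The bias is optional; the factory methods below may pad, shuffle or
    zero-replicate it so that it is applied exactly once when the sharded output is gathered.
    """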
    weight: ttnn.Tensor
    bias: ttnn.Tensor | None

    @classmethod
    def from_torch(
        cls,
        state: dict[str, torch.Tensor],
        *,
        dtype: ttnn.DataType | None = None,
        device: ttnn.Device,
        shard_dim: int | None = None,
        unsqueeze_bias: bool = False,
        hidden_dim_padding: int,
        parallel_config: DiTParallelConfig,
    ) -> TtLinearParameters:
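        """Build parameters from a torch linear state dict (keys "weight" and optionally "bias").

        When hidden_dim_padding > 0, dimensions that are multiples of the hard-coded base hidden
        size (2432) are zero-padded block-wise to multiples of hidden size + padding. shard_dim
        selects whether the weight (and bias) are sharded along the input dimension (0 / -2),
        the output dimension (1 / -1), or replicated across the mesh (None).
        """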
        if "bias" in state:
            bias = state["bias"].unsqueeze(0)
        else:
            bias = None
        weight = state["weight"]
        if hidden_dim_padding > 0:
            hidden_dim = 2432  # base hidden size; dimensions that are multiples of it are padded block-wise
            hidden_dim_new = hidden_dim + hidden_dim_padding
            weight_h, weight_w = weight.shape
            weight_h_mult = weight_h // hidden_dim
            weight_w_mult = weight_w // hidden_dim
            if weight_h % hidden_dim == 0:
                if weight_h_mult == 1:
                    weight = torch.nn.functional.pad(
                        weight, pad=(0, 0, 0, hidden_dim_padding), mode="constant", value=0
                    )
                elif weight_h_mult > 1:
                    weight = weight.reshape(weight_h_mult, hidden_dim, weight_w)
                    weight = torch.nn.functional.pad(
                        weight, pad=(0, 0, 0, hidden_dim_padding), mode="constant", value=0
                    )
                    weight = weight.reshape(weight_h_mult * hidden_dim_new, weight_w)
            weight_h, weight_w = weight.shape
            if weight_w % hidden_dim == 0:
                if weight_w_mult == 1:
                    weight = torch.nn.functional.pad(weight, pad=(0, hidden_dim_padding), mode="constant", value=0)
                elif weight_w_mult > 1:
                    weight = weight.reshape(weight_h, weight_w_mult, hidden_dim)
                    weight = torch.nn.functional.pad(weight, pad=(0, hidden_dim_padding), mode="constant", value=0)
                    weight = weight.reshape(weight_h, weight_w_mult * hidden_dim_new)

            if bias is not None:
                bias_h, bias_w = bias.shape
                bias_w_mult = bias_w // hidden_dim
                if bias_w % hidden_dim == 0:
                    if bias_w_mult == 1:
                        bias = torch.nn.functional.pad(bias, pad=(0, hidden_dim_padding), mode="constant", value=0)
                    elif bias_w_mult > 1:
                        bias = bias.reshape(bias_h, bias_w_mult, hidden_dim)
                        bias = torch.nn.functional.pad(bias, pad=(0, hidden_dim_padding), mode="constant", value=0)
                        bias = bias.reshape(bias_h, bias_w_mult * hidden_dim_new)

        if unsqueeze_bias and bias is not None:
            # TODO: remove this workaround for https://github.com/tenstorrent/tt-metal/issues/16599 once the issue is resolved
            bias = bias.unsqueeze(0)

        if shard_dim in [0, -2]:
            # Shard the bias of a linear operation on the first dimension.
            # A single device receives the bias as is, while the other ones receive zero tensors of the
            # same shape so that the bias is not added multiple times after gathering.
            tp_factor = parallel_config.tensor_parallel.factor
            if bias is not None:
                zeros = torch.zeros_like(bias)
                bias = torch.cat([bias] + [zeros] * (tp_factor - 1), dim=0)
            bias_dims = [None, None]
            bias_dims[parallel_config.tensor_parallel.mesh_axis] = 0
            bias_mm = ttnn.ShardTensor2dMesh(device, mesh_shape=tuple(device.shape), dims=bias_dims)
        elif shard_dim in [1, -1]:
            bias_dims = [None, None]
            bias_dims[parallel_config.tensor_parallel.mesh_axis] = shard_dim
            bias_mm = ttnn.ShardTensor2dMesh(device, mesh_shape=tuple(device.shape), dims=bias_dims)
        else:
            bias_mm = ttnn.ShardTensor2dMesh(device, mesh_shape=tuple(device.shape), dims=[None, None])

        dims = [None, None]
        dims[parallel_config.tensor_parallel.mesh_axis] = shard_dim
        return cls(
            weight=from_torch_fast_2d(
                weight.transpose(0, 1),
                mesh_device=device,
                mesh_shape=tuple(device.shape),
                dims=dims,
                layout=ttnn.TILE_LAYOUT,
                dtype=dtype,
            ),
            bias=from_torch_fast_2d(
                bias,
                mesh_device=device,
                mesh_shape=tuple(device.shape),
                dims=dims,
                dtype=dtype,
                layout=ttnn.TILE_LAYOUT,
                mesh_mapper=bias_mm,
            )
            if bias is not None
            else None,
        )

    @classmethod
    def from_torch_col_parallel(
        cls,
        state: dict[str, torch.Tensor],
        *,
        n_local_heads: int,
        unpadded_num_heads: int,
        hidden_dim_padding: int,
        dtype: ttnn.DataType | None = None,
        device: ttnn.Device,
        parallel_config: DiTParallelConfig,
    ) -> TtLinearParameters:
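        """Build column-parallel parameters for a projection whose output features are grouped as
        three sets of attention heads (as in a fused QKV projection).

        If the unpadded head count cannot fill n_local_heads on every device, extra zero heads are
        padded in, and the heads are then shuffled so that each mesh column receives a contiguous
        block of all three groups for its own heads (see shuffle_heads below).
        """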
        if "bias" in state:
            torch_bias = state["bias"].unsqueeze(0)
        else:
            torch_bias = None

        weight = state["weight"]
        torch_weight = weight.transpose(0, 1)
        if unpadded_num_heads < (n_local_heads * parallel_config.tensor_parallel.factor):
            head_size = torch_weight.shape[1] // 3 // unpadded_num_heads
            head_padding = parallel_config.tensor_parallel.factor - (
                unpadded_num_heads % parallel_config.tensor_parallel.factor
            )
            weight_h, weight_w = torch_weight.shape

            torch_weight = torch_weight.reshape(weight_h, 3, unpadded_num_heads, head_size)
            torch_weight = torch.nn.functional.pad(
                torch_weight, pad=(0, 0, 0, head_padding, 0, 0, 0, hidden_dim_padding), mode="constant", value=0
            )
            torch_weight = torch_weight.reshape(weight_h + hidden_dim_padding, -1)
            if torch_bias is not None:
                bias_h, bias_w = torch_bias.shape
                torch_bias = torch_bias.reshape(bias_h, 3, unpadded_num_heads, head_size)
                torch_bias = torch.nn.functional.pad(torch_bias, pad=(0, 0, 0, head_padding), mode="constant", value=0)
                torch_bias = torch_bias.reshape(bias_h, -1)

        def shuffle_heads(tensor):
            # Given a torch tensor with output features in the last dimension,
            # shuffle the heads to allow for column-parallel computation.
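            # For example (illustrative sizes), with a tensor-parallel factor of 2 and 2 local heads,
            # the fused output ordering [q0 q1 q2 q3 | k0 k1 k2 k3 | v0 v1 v2 v3] becomes
            # [q0 q1 k0 k1 v0 v1 | q2 q3 k2 k3 v2 v3], so each shard holds all three groups
            # for its own heads.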
            in_dim = tensor.shape[0]
            tensor = tensor.reshape(
                in_dim, 3, parallel_config.tensor_parallel.factor, n_local_heads, -1
            )  # [ID, 3, ND, NLH, DH]
            tensor = tensor.permute(0, 2, 1, 3, 4)  # [ID, ND, 3, NLH, DH]
            tensor = tensor.reshape(in_dim, -1)  # [ID, ND*3*NLH*DH]
            return tensor

        dims = [None, None]
        dims[parallel_config.tensor_parallel.mesh_axis] = -1
        return cls(
            weight=from_torch_fast_2d(
                shuffle_heads(torch_weight),
                mesh_device=device,
                mesh_shape=tuple(device.shape),
                dims=dims,
                layout=ttnn.TILE_LAYOUT,
                dtype=dtype,
            ),
            bias=from_torch_fast_2d(
                shuffle_heads(torch_bias),
                mesh_device=device,
                mesh_shape=tuple(device.shape),
                dims=dims,
                layout=ttnn.TILE_LAYOUT,
                dtype=dtype,
            )
            if torch_bias is not None
            else None,
        )

    @classmethod
    def from_torch_time_embed(
        cls,
        state: dict[str, torch.Tensor],
        *,
        num_chunks: int,
        hidden_dim_padding: int,
        dtype: ttnn.DataType | None = None,
        device: ttnn.Device,
        unsqueeze_bias: bool = False,
        parallel_config: DiTParallelConfig,
    ) -> TtLinearParameters:
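        """Build column-parallel parameters for a projection whose output features are split into
        num_chunks equal chunks.

        Hidden-dimension padding is applied to the input dimension and to each chunk, and the
        chunks are then shuffled so that each mesh column receives a slice of every chunk
        (see shuffle_chunks below).
        """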
        if "bias" in state:
            torch_bias = state["bias"].unsqueeze(0)
        else:
            torch_bias = None
        weight = state["weight"]
        torch_weight = weight.transpose(0, 1)

        if hidden_dim_padding:
            weight_h, weight_w = torch_weight.shape
            torch_weight = torch_weight.reshape(weight_h, num_chunks, -1)
            torch_weight = torch.nn.functional.pad(
                torch_weight, pad=(0, hidden_dim_padding, 0, 0, 0, hidden_dim_padding), mode="constant", value=0
            )
            torch_weight = torch_weight.reshape(weight_h + hidden_dim_padding, -1)

            if torch_bias is not None:
                bias_h, bias_w = torch_bias.shape
                torch_bias = torch_bias.reshape(bias_h, num_chunks, -1)
                torch_bias = torch.nn.functional.pad(torch_bias, pad=(0, hidden_dim_padding), mode="constant", value=0)
                torch_bias = torch_bias.reshape(bias_h, -1)

        def shuffle_chunks(tensor):
            # Given a torch tensor with output features in the last dimension,
            # shuffle the chunks to allow for column-parallel computation.
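            # For example (illustrative sizes), with 2 chunks and a tensor-parallel factor of 2,
            # the output ordering [c0a c0b | c1a c1b] becomes [c0a c1a | c0b c1b], so each shard
            # holds a slice of every chunk.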
            in_dim = tensor.shape[0]
            tensor = tensor.reshape(in_dim, num_chunks, parallel_config.tensor_parallel.factor, -1)
            tensor = tensor.permute(0, 2, 1, 3)
            tensor = tensor.reshape(in_dim, -1)
            return tensor

        torch_weight = shuffle_chunks(torch_weight)
        if torch_bias is not None:
            torch_bias = shuffle_chunks(torch_bias)

        if unsqueeze_bias and torch_bias is not None:
            # TODO: remove this workaround for https://github.com/tenstorrent/tt-metal/issues/16599 once the issue is resolved
            torch_bias = torch_bias.unsqueeze(0)

        dims = [None, None]
        dims[parallel_config.tensor_parallel.mesh_axis] = -1
        return cls(
            weight=from_torch_fast_2d(
                torch_weight,
                mesh_device=device,
                mesh_shape=tuple(device.shape),
                dims=dims,
                layout=ttnn.TILE_LAYOUT,
                dtype=dtype,
            ),
            bias=from_torch_fast_2d(
                torch_bias,
                mesh_device=device,
                mesh_shape=tuple(device.shape),
                dims=dims,
                layout=ttnn.TILE_LAYOUT,
                dtype=dtype,
            )
            if torch_bias is not None
            else None,
        )

    @property
    def in_channels(self) -> int:
        return self.weight.shape[0]

    @property
    def out_channels(self) -> int:
        return self.weight.shape[1]

    @classmethod
    def from_torch_sharding_projection(
        cls,
        state: dict[str, torch.Tensor],
        *,
        dtype: ttnn.DataType | None = None,
        device: ttnn.Device,
        shard_dim: int,
    ) -> TtLinearParameters:
        weight = state["weight"]

        return cls(
            weight=from_torch_fast_2d(
                weight,
                mesh_device=device,
                mesh_shape=tuple(device.shape),
                dims=[shard_dim, None],
                layout=ttnn.TILE_LAYOUT,
                dtype=dtype,
            ),
            bias=None,
        )


def sd_linear(
    x: ttnn.Tensor,
    parameters: TtLinearParameters,
    *,
    memory_config: ttnn.MemoryConfig | None = None,
    program_config: ttnn.MatmulProgramConfig | None = None,
    compute_kernel_config: ttnn.ComputeKernelConfig | None = None,
    core_grid: ttnn.CoreGrid | None = None,
    output_tile: list[int] | None = None,
    dtype: ttnn.DataType | None = None,
    activation: str | None = None,
    deallocate: bool = False,
    prob: bool = False,
    transpose_a: bool = False,
) -> ttnn.Tensor:
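    """Apply a linear layer via ttnn.linear using the given parameters.

    The input's channel dimension (the last one, or the second-to-last when transpose_a is set)
    must match parameters.in_channels. Inputs of shape Mx1x1xN are temporarily squeezed to
    Mx1xN to work around a correctness issue, and the original rank is restored on the output.
    """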
    msg = f"last value in input shape {list(x.shape)} should be equal to {parameters.in_channels}"
    if not transpose_a:
        assert x.shape[-1] == parameters.in_channels, msg
    else:
        assert x.shape[-2] == parameters.in_channels, msg

    weight = parameters.weight
    bias = parameters.bias

    # There is a correctness issue with tensors of shape Mx1x1xN, so squeeze them to Mx1xN.
    squeeze = len(x.shape) == 4 and x.shape[1] == 1 and x.shape[2] == 1
    if squeeze:
        x = x.reshape([x.shape[0], 1, x.shape[-1]])

    assert x.device() == weight.device()
    if bias is not None:
        assert x.device() == bias.device()

    output = ttnn.linear(
        x,
        weight,
        bias=bias,
        memory_config=memory_config,
        program_config=program_config,
        compute_kernel_config=compute_kernel_config,
        core_grid=core_grid,
        output_tile=output_tile,
        dtype=dtype,
        activation=activation,
        transpose_a=transpose_a,
    )

    # if deallocate:
    #     ttnn.deallocate(x)

    if squeeze:
        output = output.reshape([output.shape[0], 1, 1, output.shape[-1]])

    return output
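

# A minimal usage sketch (illustrative only): the mesh device, the source torch module and the
# DiTParallelConfig instance are assumptions provided by the surrounding model code.
#
#   parameters = TtLinearParameters.from_torch(
#       {"weight": torch_linear.weight, "bias": torch_linear.bias},
#       dtype=ttnn.bfloat16,
#       device=mesh_device,
#       shard_dim=-1,
#       hidden_dim_padding=0,
#       parallel_config=parallel_config,
#   )
#   x_tt = ttnn.from_torch(x, device=mesh_device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16)
#   y_tt = sd_linear(x_tt, parameters)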
